import java.util.HashMap;

public class MapSum {

    HashMap<String, Integer> map;
    HashMap<String, Integer> set;

    public MapSum() {
        map = new HashMap<>();
        set = new HashMap<>();
    }

    public void insert(String key, int val) {
        int delta = val - set.getOrDefault(key, 0);
        for (int i = 0; i < key.length(); i++) {
            String temp = key.substring(0, i + 1);
            map.put(temp, map.getOrDefault(temp, 0) + delta);
        }
        set.put(key, val);
    }

    public int sum(String prefix) {
        return map.getOrDefault(prefix, 0);
    }
}
